package accounts

import (
	"notabug.org/apiote/amuse/db"

	"bytes"
	"encoding/base64"
	"errors"
	"fmt"
	"strings"

	"github.com/pquerna/otp/totp"
	"golang.org/x/crypto/argon2"
	"notabug.org/apiote/gott"
)

// AuthData carries the credentials submitted to Login.
type AuthData struct {
	username string
	password string
	// sfa is the second-factor input: checkSfa tries it first as a recovery
	// code and otherwise, with spaces removed, as a TOTP token.
	sfa string
	// remember is accepted by Login; createSession does not consult it yet.
	remember bool
}

// AuthResult accumulates what the Login pipeline learns about the user while
// each step runs.
type AuthResult struct {
	// user is the record loaded by findUser.
	user *db.User

	// Values copied out of user by unmarshalUser.
	passwordHash, sfaSecret, recoveryCodesRaw string
	// recoveryCodes holds the individual codes parsed from the
	// comma-separated recoveryCodesRaw.
	recoveryCodes []string

	// token is the id of the session created by createSession.
	token string
}

// Argon is the state threaded through the password-verification pipeline in
// checkPassword: the candidate password, the stored encoded hash, and the
// values decoded from it.
type Argon struct {
	// password is the plaintext candidate. argon is the stored encoded hash.
	// It is expected to look like "$<f1>$<f2>$m=…,t=…,p=…$<salt>$<hash>";
	// fields 1 and 2 are not read by this file.
	password, argon string
	// parts is argon split on "$" (see splitArgon).
	parts []string

	// Cost parameters parsed from parts[3] by decodeArgonParams.
	memory, time uint32
	threads      uint8

	// Raw salt and hash decoded from parts[4] and parts[5]; keyLen is len(hash).
	salt, hash []byte
	keyLen     uint32
}

// findUser loads the account named by AuthData.username via db.GetUser and
// stores it in AuthResult.user.
//
// A db.EmptyError is converted to an AuthError, including one the db layer
// has wrapped. Presumably it signals that no such user exists; verify
// against the db package. Any other error is returned unchanged.
func findUser(args ...interface{}) (interface{}, error) {
	authData := args[0].(*AuthData)
	authResult := args[1].(*AuthResult)
	user, err := db.GetUser(authData.username)
	authResult.user = user
	var empty db.EmptyError
	if errors.As(err, &empty) {
		err = AuthError{Err: empty}
	}
	return gott.Tuple(args), err
}

// unmarshalUser copies the password hash, second-factor secret and recovery
// codes from AuthResult.user into the flat fields of AuthResult.
//
// Recovery codes are stored as one comma-separated string. Empty entries are
// dropped, because strings.Split("", ",") yields [""]. An empty "code" could
// otherwise be matched by an empty second-factor input.
func unmarshalUser(args ...interface{}) interface{} {
	authResult := args[1].(*AuthResult)
	user := authResult.user
	authResult.passwordHash = user.PasswordHash
	authResult.sfaSecret = user.Sfa
	authResult.recoveryCodesRaw = user.RecoveryCodes
	authResult.recoveryCodes = strings.FieldsFunc(user.RecoveryCodes, func(r rune) bool {
		return r == ','
	})
	return gott.Tuple(args)
}

// splitArgon breaks the encoded hash held in argon.argon into its
// "$"-separated fields and records them in argon.parts. It never fails; the
// field count is not checked here.
func splitArgon(args ...interface{}) interface{} {
	encoded := args[0].(*Argon)
	fields := strings.Split(encoded.argon, "$")
	encoded.parts = fields
	return gott.Tuple(args)
}

// decodeArgonParams parses the "m=<memory>,t=<time>,p=<threads>" field
// (parts[3]) of the encoded hash into argon.memory, argon.time and
// argon.threads.
//
// It returns an error instead of panicking in two cases:
//   - the encoded hash has fewer fields than the pipeline reads (up to
//     parts[5]);
//   - t or p is zero, because argon2.IDKey panics on those values.
func decodeArgonParams(args ...interface{}) (interface{}, error) {
	argon := args[0].(*Argon)
	if len(argon.parts) < 6 {
		return gott.Tuple(args), fmt.Errorf(
			"malformed argon2 hash: expected 6 fields, got %d", len(argon.parts))
	}
	_, err := fmt.Sscanf(argon.parts[3], "m=%d,t=%d,p=%d", &argon.memory,
		&argon.time, &argon.threads)
	if err != nil {
		return gott.Tuple(args), err
	}
	if argon.time < 1 || argon.threads < 1 {
		return gott.Tuple(args), fmt.Errorf(
			"invalid argon2 parameters: t=%d, p=%d", argon.time, argon.threads)
	}
	return gott.Tuple(args), nil
}

// decodeSalt decodes the unpadded standard-base64 salt in parts[4] into
// argon.salt.
//
// A missing salt field yields an error rather than an index-out-of-range
// panic.
func decodeSalt(args ...interface{}) (interface{}, error) {
	argon := args[0].(*Argon)
	if len(argon.parts) < 5 {
		return gott.Tuple(args), errors.New("malformed argon2 hash: missing salt")
	}
	salt, err := base64.RawStdEncoding.DecodeString(argon.parts[4])
	argon.salt = salt
	return gott.Tuple(args), err
}

// decodeHash decodes the unpadded standard-base64 hash in parts[5] into
// argon.hash. It sets argon.keyLen to the hash length so that the derived key
// in compareArgon has the same size.
//
// A missing or empty hash field is an error. An empty hash would mean a
// zero-length key, which must never be allowed to verify a password.
func decodeHash(args ...interface{}) (interface{}, error) {
	argon := args[0].(*Argon)
	if len(argon.parts) < 6 {
		return gott.Tuple(args), errors.New("malformed argon2 hash: missing hash")
	}
	hash, err := base64.RawStdEncoding.DecodeString(argon.parts[5])
	if err != nil {
		return gott.Tuple(args), err
	}
	if len(hash) == 0 {
		return gott.Tuple(args), errors.New("malformed argon2 hash: empty hash")
	}
	argon.hash = hash
	argon.keyLen = uint32(len(hash))
	return gott.Tuple(args), nil
}

// compareArgon derives an Argon2id key from argon.password using the decoded
// salt and cost parameters, then checks it against the stored hash. A mismatch
// is reported as an AuthError.
//
// An empty stored hash is rejected up front. An empty-vs-empty comparison must
// never count as a successful password check.
func compareArgon(args ...interface{}) (interface{}, error) {
	argon := args[0].(*Argon)
	mismatch := AuthError{Err: errors.New("Password does not match")}
	if len(argon.hash) == 0 {
		return gott.Tuple(args), mismatch
	}
	derived := argon2.IDKey([]byte(argon.password), argon.salt, argon.time,
		argon.memory, argon.threads, argon.keyLen)
	// NOTE(review): bytes.Equal is not constant-time. crypto/subtle's
	// ConstantTimeCompare is the usual choice when comparing secrets.
	if !bytes.Equal(derived, argon.hash) {
		return gott.Tuple(args), mismatch
	}
	return gott.Tuple(args), nil
}

// checkPassword verifies authData.password against the encoded Argon2 hash in
// authResult.passwordHash. It runs the split → decode params → decode salt →
// decode hash → compare pipeline on a fresh Argon value. The first failing
// step's error is returned; args is passed through unchanged.
func checkPassword(args ...interface{}) (interface{}, error) {
	authData := args[0].(*AuthData)
	authResult := args[1].(*AuthResult)
	candidate := &Argon{
		argon:    authResult.passwordHash,
		password: authData.password,
	}
	pipeline := gott.NewResult(gott.Tuple{candidate}).
		Map(splitArgon).
		Bind(decodeArgonParams).
		Bind(decodeSalt).
		Bind(decodeHash).
		Bind(compareArgon)
	_, err := pipeline.Finish()
	return gott.Tuple(args), err
}

// checkSfa enforces the second factor when the user has one configured
// (authResult.sfaSecret non-empty). Otherwise it succeeds immediately.
//
// authData.sfa is first tried as a single-use recovery code. On a match, the
// code is removed from authResult.recoveryCodes and recoveryCodesRaw is
// rebuilt so updateSfa can persist the change. Otherwise spaces are stripped
// from authData.sfa and it is validated as a TOTP token against sfaSecret. Any
// other input yields an AuthError.
func checkSfa(args ...interface{}) (interface{}, error) {
	authData := args[0].(*AuthData)
	authResult := args[1].(*AuthResult)
	if authResult.sfaSecret == "" {
		return gott.Tuple(args), nil
	}

	for i, code := range authResult.recoveryCodes {
		// Empty entries, such as those from splitting an empty or malformed
		// list, must never match. Otherwise an empty input would bypass 2FA.
		if code == "" || authData.sfa != code {
			continue
		}
		authResult.recoveryCodes = append(authResult.recoveryCodes[:i],
			authResult.recoveryCodes[i+1:]...)
		authResult.recoveryCodesRaw = strings.Join(authResult.recoveryCodes, ",")
		return gott.Tuple(args), nil
	}

	authData.sfa = strings.ReplaceAll(authData.sfa, " ", "")
	if totp.Validate(authData.sfa, authResult.sfaSecret) {
		return gott.Tuple(args), nil
	}

	return gott.Tuple(args), AuthError{Err: errors.New("Wrong TOTP token")}
}

// updateSfa writes authResult.recoveryCodesRaw back for the user via
// db.UpdateRecoveryCodes, so a recovery code consumed in checkSfa is not
// reusable. The write happens unconditionally, even when no code was consumed.
func updateSfa(args ...interface{}) (interface{}, error) {
	username := args[0].(*AuthData).username
	codes := args[1].(*AuthResult).recoveryCodesRaw
	return gott.Tuple(args), db.UpdateRecoveryCodes(username, codes)
}

// createSession opens a session for authData.username and stores its id in
// authResult.token.
//
// The error is checked before the returned session is dereferenced. When
// CreateSession fails, its session value is not guaranteed to be usable.
func createSession(args ...interface{}) (interface{}, error) {
	authData := args[0].(*AuthData)
	authResult := args[1].(*AuthResult)
	// TODO: long sessions. authData.remember is not passed yet; the second
	// argument (presumably the long-session flag, to be confirmed in db) is
	// always false.
	session, err := db.CreateSession(authData.username, false)
	if err != nil {
		return gott.Tuple(args), err
	}
	authResult.token = session.Id
	return gott.Tuple(args), nil
}

// clearSessions calls db.ClearSessions for the user loaded into
// AuthResult.user and returns its error unchanged.
// NOTE(review): whether this drops all of the user's sessions or only stale
// ones depends on db.ClearSessions. Confirm there.
func clearSessions(args ...interface{}) (interface{}, error) {
	username := args[1].(*AuthResult).user.Username
	if err := db.ClearSessions(username); err != nil {
		return gott.Tuple(args), err
	}
	return gott.Tuple(args), nil
}

// Login authenticates username with password and, when the account has a
// second factor configured, with sfa (a TOTP token or a recovery code). On
// success it returns the id of the newly created session.
//
// remember is accepted but not yet used when creating the session.
//
// Existing sessions are cleared only after both factors have been verified.
// Clearing them earlier would let any unauthenticated caller who knows a
// username trigger db.ClearSessions for that user.
func Login(username, password, sfa string, remember bool) (string, error) {
	r, err := gott.
		NewResult(gott.Tuple{&AuthData{username: username, password: password,
			sfa: sfa, remember: remember}, &AuthResult{}}).
		Bind(findUser).
		Map(unmarshalUser).
		Bind(checkPassword).
		Bind(checkSfa).
		Bind(updateSfa).
		Bind(clearSessions).
		Bind(createSession).
		Finish()
	if err != nil {
		return "", err
	}
	return r.(gott.Tuple)[1].(*AuthResult).token, nil
}
